from diffusers import DDPMScheduler
import torch

class ScaledDDPMScheduler(DDPMScheduler):
    """DDPM scheduler whose beta schedule is raised element-wise to a power.

    After the parent scheduler builds its ``betas``, each beta is replaced by
    ``beta ** factor`` and ``alphas`` / ``alphas_cumprod`` are recomputed from
    the new betas. For betas in (0, 1), ``factor < 1`` makes every beta larger
    (more noise per step) and ``factor > 1`` makes every beta smaller.
    """

    def __init__(self, factor=0.8, *args, **kwargs):
        """Build the parent scheduler, then rescale its betas.

        Args:
            factor: Strictly positive exponent applied element-wise to
                ``betas``.
            *args, **kwargs: Forwarded unchanged to ``DDPMScheduler.__init__``.

        Raises:
            ValueError: If ``factor`` is not strictly positive.

        NOTE: ``factor`` is the first positional parameter, so a positional
        argument is consumed as ``factor`` before anything reaches the parent
        (``ScaledDDPMScheduler(1000)`` means ``factor=1000``). Pass the
        parent's options by keyword.
        NOTE(review): ``factor`` is not forwarded to the parent, so it is
        presumably absent from ``self.config``; a config saved from this class
        may not reproduce the scaled schedule when reloaded — confirm.
        """
        if factor <= 0:
            # beta ** 0 == 1 zeroes every alpha, and a negative exponent pushes
            # betas in (0, 1) above 1 (negative alphas); both are degenerate.
            raise ValueError(f"factor must be positive, got {factor!r}")
        super().__init__(*args, **kwargs)
        self._adjust_betas(factor)

    def _adjust_betas(self, factor):
        """Raise ``self.betas`` to ``factor`` and recompute derived terms.

        ``alphas`` (``1 - betas``) and ``alphas_cumprod`` (their running
        product over timesteps) are rebuilt from the scaled betas so that
        parent methods reading these attributes — presumably including
        ``add_noise`` — see the scaled schedule.
        """
        self.betas = self.betas ** factor
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    def noise_sampling(self, x, timesteps=None, noise=None):
        """Noise a batch ``x`` at the given (or randomly drawn) timesteps.

        Args:
            x: Clean samples; the leading dimension is treated as the batch
                size.
            timesteps: Per-sample timestep indices, passed straight to
                ``add_noise``. When ``None``, one index per sample is drawn
                uniformly from ``[0, self.config.num_train_timesteps)`` as an
                int64 tensor on ``x.device``.
            noise: Optional noise to use, expected to have ``x``'s shape. When
                ``None``, standard-normal noise with ``x``'s shape, dtype and
                device is drawn. Supply it when the caller needs the noise
                itself, e.g. as an epsilon-prediction training target.

        Returns:
            Tuple ``(samples, timesteps)`` where ``samples`` is
            ``self.add_noise(x, noise, timesteps)``.
        """
        if noise is None:
            # randn_like keeps x's dtype; plain randn yields float32 noise,
            # which would promote a half-precision x to float32 when mixed.
            noise = torch.randn_like(x)
        if timesteps is None:
            timesteps = torch.randint(
                0,
                self.config.num_train_timesteps,
                (x.shape[0],),
                device=x.device,
                dtype=torch.long,
            )
        samples = self.add_noise(x, noise, timesteps)
        return samples, timesteps